{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import theano.tensor as tt\n",
    "from keras import activations, initializations\n",
    "from keras.layers.core import Dense, MaskedLayer, Layer, Merge\n",
    "from keras.layers.recurrent import GRU\n",
    "from keras.models import Graph\n",
    "from keras.utils.theano_utils import shared_zeros"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class SoftSequentialAttentionLayer(MaskedLayer):\n",
    "    \"\"\"Soft attention over a sequential memory, conditioned on a driver input.\n",
    "\n",
    "    Work in progress: unnormalised attention scores are computed in\n",
    "    ``get_output``, but the normalisation over time and the final output are\n",
    "    still TODO, so ``get_output`` currently returns ``None``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    memory_dim : int\n",
    "        Size of the last axis of the sequential memory input.\n",
    "    driver_dim : int\n",
    "        Size of the last axis of the attention driver input.\n",
    "    inner_dim : int\n",
    "        Size of the hidden projection used to score each time step.\n",
    "    init : str\n",
    "        Weight initialisation name, resolved with ``initializations.get``.\n",
    "    inner_activation : str\n",
    "        Hidden activation name, resolved with ``activations.get``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, memory_dim, driver_dim, inner_dim=128, init='glorot_uniform', inner_activation='relu'):\n",
    "        super(SoftSequentialAttentionLayer, self).__init__()\n",
    "        self.init = initializations.get(init)\n",
    "        self.inner_activation = activations.get(inner_activation)\n",
    "        # Projections of the memory and of the driver into the shared inner space.\n",
    "        self.W_m = self.init((memory_dim, inner_dim))\n",
    "        self.W_d = self.init((driver_dim, inner_dim))\n",
    "        # Projection of the inner representation to one score per time step.\n",
    "        self.W_a = self.init((inner_dim, 1))\n",
    "        self.b_inner = shared_zeros(inner_dim)\n",
    "        self.b_out = shared_zeros(1)\n",
    "        # NOTE(review): layers of this Keras generation appear to expose their\n",
    "        # trainable weights through ``self.params`` (read by ``get_params``) -- confirm.\n",
    "        self.params = [self.W_m, self.W_d, self.W_a, self.b_inner, self.b_out]\n",
    "\n",
    "    def set_previous(self, *previous_layers):\n",
    "        \"\"\"Connect exactly two upstream layers: (sequential_memory, attention_driver).\n",
    "\n",
    "        Raises ValueError when the number of layers is not 2, or when the first\n",
    "        layer does not have a truthy ``return_sequences`` attribute.\n",
    "        \"\"\"\n",
    "        type_name = self.__class__.__name__\n",
    "        if len(previous_layers) != 2:\n",
    "            raise ValueError(\"{}.set_previous expects 2 input layers, got {}\".format(\n",
    "                type_name, len(previous_layers)))\n",
    "        sequential_memory, attention_driver = previous_layers\n",
    "        # getattr: a layer without the attribute should trigger the ValueError\n",
    "        # below rather than an AttributeError.\n",
    "        if not getattr(sequential_memory, 'return_sequences', False):\n",
    "            raise ValueError(\"The first input of {} should be a recurrent layer with\"\n",
    "                             \" return_sequences=True\".format(type_name))\n",
    "        self.sequential_memory = sequential_memory\n",
    "        self.attention_driver = attention_driver\n",
    "\n",
    "    def get_input(self, train=False):\n",
    "        \"\"\"Return the outputs of both upstream layers as [memory, driver].\"\"\"\n",
    "        return [self.sequential_memory.get_output(train=train),\n",
    "                self.attention_driver.get_output(train=train)]\n",
    "\n",
    "    def get_output(self, train=False):\n",
    "        \"\"\"Compute unnormalised attention scores; the final output is still TODO.\n",
    "\n",
    "        Currently always returns ``None``.\n",
    "        \"\"\"\n",
    "        sequential_memory, attention_driver = self.get_input(train=train)\n",
    "        # assumes sequential_memory is (nb_samples, time (padded with zeros), memory_dim)\n",
    "        # assumes attention_driver is (nb_samples, driver_dim) -- TODO confirm\n",
    "        # NOTE(review): get_padded_shuffled_mask is not defined in this class;\n",
    "        # presumably expected from a base class -- confirm it exists on MaskedLayer.\n",
    "        padded_mask = self.get_padded_shuffled_mask(train, sequential_memory, pad=1)\n",
    "        # new shape: (time, nb_samples, memory_dim) -> because theano.scan iterates over main dimension\n",
    "        sequential_memory = sequential_memory.dimshuffle((1, 0, 2))\n",
    "        h = self.inner_activation(tt.dot(sequential_memory, self.W_m)\n",
    "                                  + tt.dot(attention_driver, self.W_d)\n",
    "                                  + self.b_inner)\n",
    "        # Unnormalised (exponentiated) scores; not consumed yet.\n",
    "        a = tt.exp(tt.dot(h, self.W_a) + self.b_out)\n",
    "\n",
    "        output = None  #XXX: TODO\n",
    "        return output\n",
    "\n",
    "    def _variable_length_softmax_step(self, a_t, sum_t):\n",
    "        \"\"\"Placeholder for the per-time-step softmax helper; not implemented yet.\"\"\"\n",
    "        raise NotImplementedError('_variable_length_softmax_step is not implemented yet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class CustomGraph(Graph):\n",
    "    \"\"\"Graph whose ``add_node`` additionally accepts ``merge_mode='distinct'``.\n",
    "\n",
    "    With ``merge_mode='distinct'`` the layers named in ``inputs`` are passed to\n",
    "    ``layer.set_previous`` as separate positional arguments instead of being\n",
    "    wrapped in a single ``Merge`` layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def add_node(self, layer, name, input=None, inputs=[], merge_mode='concat', create_output=False):\n",
    "        \"\"\"Register ``layer`` under ``name`` and connect it to its input(s).\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        layer : the layer object to add.\n",
    "        name : identifier of the new node; must not already be in the namespace.\n",
    "        input : identifier of a single upstream node or graph input.\n",
    "        inputs : identifiers of several upstream nodes/inputs.\n",
    "        merge_mode : mode given to ``Merge``, or ``'distinct'`` to skip merging.\n",
    "        create_output : when true, also add an output named ``name``.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        Exception\n",
    "            On a duplicate ``name`` or an unknown ``input`` / ``inputs`` identifier.\n",
    "        \"\"\"\n",
    "        # Work on a private copy so neither the shared default list nor the\n",
    "        # caller's list is aliased by the node config stored below.\n",
    "        inputs = list(inputs) if inputs else []\n",
    "\n",
    "        if hasattr(layer, 'set_name'):\n",
    "            layer.set_name(name)\n",
    "        if name in self.namespace:\n",
    "            raise Exception('Duplicate node identifier: ' + name)\n",
    "\n",
    "        if input:\n",
    "            if input not in self.namespace:\n",
    "                raise Exception('Unknown node/input identifier: ' + input)\n",
    "            if input in self.nodes:\n",
    "                layer.set_previous(self.nodes[input])\n",
    "            elif input in self.inputs:\n",
    "                layer.set_previous(self.inputs[input])\n",
    "\n",
    "        if inputs:\n",
    "            to_merge = []\n",
    "            for n in inputs:\n",
    "                # Nodes take precedence over graph inputs, as for ``input`` above.\n",
    "                if n in self.nodes:\n",
    "                    to_merge.append(self.nodes[n])\n",
    "                elif n in self.inputs:\n",
    "                    to_merge.append(self.inputs[n])\n",
    "                else:\n",
    "                    raise Exception('Unknown identifier: ' + n)\n",
    "            # XXX: here is the change\n",
    "            if merge_mode == 'distinct':\n",
    "                # Hand the upstream layers over individually, unmerged.\n",
    "                layer.set_previous(*to_merge)\n",
    "            else:\n",
    "                merge = Merge(to_merge, mode=merge_mode)\n",
    "                layer.set_previous(merge)\n",
    "\n",
    "        self.namespace.add(name)\n",
    "        self.nodes[name] = layer\n",
    "        self.node_config.append({'name': name,\n",
    "                                 'input': input,\n",
    "                                 'inputs': inputs,\n",
    "                                 'merge_mode': merge_mode})\n",
    "        layer.init_updates()\n",
    "        # Fold the layer's parameters and training metadata into the graph's own.\n",
    "        params, regularizers, constraints, updates = layer.get_params()\n",
    "        self.params += params\n",
    "        self.regularizers += regularizers\n",
    "        self.constraints += constraints\n",
    "        self.updates += updates\n",
    "\n",
    "        if create_output:\n",
    "            self.add_output(name, input=name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Layer sizes, named so the attention layer can be given matching dimensions.\n",
    "INPUT_DIM = 32    # feature size fed to both the GRU and the Dense layer\n",
    "MEMORY_DIM = 128  # GRU output size; assumed equal to the GRU default used before -- confirm\n",
    "DRIVER_DIM = 4    # Dense output size\n",
    "\n",
    "graph = CustomGraph()\n",
    "graph.add_input(name='context_sequences', ndim=3)\n",
    "graph.add_node(GRU(INPUT_DIM, output_dim=MEMORY_DIM, return_sequences=True),\n",
    "               name='dense1', input='context_sequences')\n",
    "# NOTE(review): 'dense2' is fed the 3-D 'context_sequences' input, while the\n",
    "# attention layer's comments describe a 2-D driver -- confirm the intended source.\n",
    "graph.add_node(Dense(INPUT_DIM, DRIVER_DIM), name='dense2', input='context_sequences')\n",
    "# The attention layer's constructor requires the memory and driver sizes;\n",
    "# calling it without arguments raises a TypeError.\n",
    "graph.add_node(SoftSequentialAttentionLayer(MEMORY_DIM, DRIVER_DIM),\n",
    "               name='attention', inputs=['dense1', 'dense2'],\n",
    "               merge_mode='distinct')\n",
    "graph.add_output(name='output1', input='dense2')\n",
    "graph.add_output(name='output2', input='attention')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention': <__main__.SoftSequentialAttentionLayer at 0x10873d630>,\n",
       " 'dense1': <keras.layers.recurrent.GRU at 0x1085caeb8>,\n",
       " 'dense2': <keras.layers.core.Dense at 0x10873f438>}"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the name -> layer mapping of the nodes registered on the graph.\n",
    "graph.nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'attention', 'context_sequences', 'dense1', 'dense2'}"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect every identifier registered on the graph (the input and the nodes).\n",
    "graph.namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Toy arrays for checking np.dot with a 3-D left operand:\n",
    "# x has shape (5, 3, 4) and a has shape (4, 2).\n",
    "x = np.arange(3 * 4 * 5).reshape(5, 3, 4)\n",
    "a = np.arange(4 * 2).reshape(4, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 3, 2)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The last axis of x (size 4) is contracted with the first axis of a (size 4);\n",
    "# the recorded result shape is (5, 3, 2).\n",
    "np.dot(x, a).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
